package thant.sqlgear;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import java.sql.Savepoint;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;

import javax.sql.DataSource;

import org.apache.commons.dbutils.QueryRunner;
import org.apache.commons.dbutils.handlers.MapListHandler;

import thant.common.map.CommonMap;
import thant.common.map.SQLMap;
import static thant.sqlgear.SQL.*;

/**
 * @ClassName: SQLRunner
 * @Description: Utility for executing SQL statements
 * @author: 肖文峰
 * @date: May 30, 2019, 3:33:27 PM
 */
public class SQLRunner {
	// Data source this runner took its connection from; stays null when the connection was supplied by the caller.
	private DataSource _ds = null;
	// Wrapped connection used whenever a method receives a null Connection argument.
	private SilentConnection connect = null;
	// Commons DbUtils runner that executes the statements.
	private QueryRunner runner = new QueryRunner();
	
	// SQL text and bind arguments of the most recently executed statement.
	private String lastSql = null;
	private Object[] lastArgs = {};
	
	// Directory that script file names are resolved against.
	private String scriptroot = "";
	
	// Optional log sink: logfunc is invoked reflectively on logger (logger is null when logfunc is a static method).
	private Object logger = null;
	private Method logfunc = null;
	
	private String forceFieldCase = "none"; // none: do not force (default); lower: force lowercase; upper: force uppercase
	
	/**
	 * Creates a runner without an associated connection or data source.
	 */
	public SQLRunner() {
	}

	/**
	 * Creates a runner bound to a connection supplied by the caller.
	 * <p>
	 * No default action (such as switching off auto-commit) is applied to an
	 * externally supplied connection. A {@code null} argument leaves the runner
	 * without a connection.
	 *
	 * @param cnt connection to wrap, may be {@code null}
	 */
	public SQLRunner(Connection cnt) {
		if (null == cnt) {
			return;
		}
		this.connect = new SilentConnection(cnt);
	}

	/**
	 * Tells whether a string is {@code null} or has zero length.
	 *
	 * @param s string to test, may be {@code null}
	 * @return {@code true} when {@code s} is {@code null} or {@code ""}
	 */
	private static boolean isEmpty(String s) {
		if (s == null) {
			return true;
		}
		return s.length() == 0;
	}
	
	/**
	 * Creates a runner that obtains its own connection from a data source.
	 * <p>
	 * Up to three times, a connection that reports itself invalid through
	 * {@code isValid(0)} is closed and replaced by a fresh one from the data
	 * source. The resulting connection is switched to manual commit, and the data
	 * source is remembered so that {@link #close()} commits and releases the
	 * connection. A {@code SQLException} raised while connecting is printed and
	 * not rethrown.
	 *
	 * @param ds data source to take the connection from; {@code null} leaves the
	 *           runner without a connection
	 */
	public SQLRunner(DataSource ds) {
		if (null == ds) {
			// Mirror SQLRunner(Connection): a null source yields an unconnected runner instead of an NPE.
			return;
		}

		Connection cnt = null;
		try {
			cnt = ds.getConnection();
			// No logging here: a log method cannot have been registered before construction finishes.
			for (int attempt = 0; attempt < 3 && !cnt.isValid(0); ++attempt) {
				// Fetch the replacement first so a pooled data source does not hand the same one back.
				Connection replacement = ds.getConnection();
				cnt.close();
				cnt = replacement;
			}
		} catch (SQLException e) {
			e.printStackTrace();
		}

		if (cnt != null) {
			connect = new SilentConnection(cnt);
			connect.setAutoCommit(false);
			_ds = ds;
		}
	}
	
	/**
	 * @return SQL text of the most recently executed statement, or {@code null} if none was run yet
	 */
	public String getLastSql() {
		return this.lastSql;
	}

	/**
	 * @return the stored argument array of the most recent statement (the internal array, not a copy)
	 */
	public Object[] getLastArgs() {
		return this.lastArgs;
	}

	/**
	 * Replaces the stored argument array; the array is kept by reference.
	 *
	 * @param lastArgs argument array to store
	 */
	public void setLastArgs(Object[] lastArgs) {
		Object[] replacement = lastArgs;
		this.lastArgs = replacement;
	}

	/**
	 * Releases the connection through {@link #close()} when the runner is finalized.
	 */
	@Override
	protected void finalize() {
		this.close();
	}
	
	/**
	 * Detaches the runner from its connection.
	 * <p>
	 * A connection that this runner obtained from a data source is committed and
	 * then closed. A connection supplied by the caller is left untouched, because
	 * its owner is responsible for releasing it. In both cases the runner forgets
	 * the connection and the data source afterwards. Failures are printed and not
	 * propagated.
	 */
	public void close() {
		if (null == connect) {
			return;
		}
		try {
			if (_ds != null) {
				try {
					connect.commit();
				} finally {
					// Always release the connection, even when the commit fails.
					connect.close();
				}
			}
		} catch (Exception e) {
			// Closing stays best-effort, but the failure is no longer hidden.
			e.printStackTrace();
		} finally {
			connect = null;
			_ds = null;
		}
	}

	/**
	 * Inserts or updates one record described by a map.
	 * <p>
	 * The action stored under {@code SQLMap.updatetypekey} decides what happens:
	 * {@code "UPDATE"} updates, {@code "INSERT"} inserts. For any other value, when
	 * a non-empty where-map is present, the matching rows are counted first: no
	 * match inserts, exactly one match updates, more than one match is rejected.
	 * Without a where-map the record is inserted.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param map record to submit, see {@link SQLMap}
	 * @return number of affected rows
	 * @throws RuntimeException if the count query fails, if more than one row
	 *         matches in replace mode, or if an update is requested without conditions
	 */
	@SuppressWarnings("unchecked")
	public int update(Connection cnt, Map<String, Object> map) {
		boolean doUpdate = false;
		boolean hasCondition = (map.get(SQLMap.whereskey) != null
			&& ((Map<String, Object>)map.get(SQLMap.whereskey)).size()>0);

		String action = (String)map.get(SQLMap.updatetypekey);
		if ("UPDATE".equals(action)) {
			doUpdate = true;
		} else if ("INSERT".equals(action)) {
			doUpdate = false;
		} else if (hasCondition) { // REPLACE
			// Count existing rows only when conditions exist, to avoid a full-table query.
			Map<String, Object> newone = new HashMap<String, Object>(map);
			newone.put(SQLMap.fieldskey, "COUNT(*)");
			newone.remove(SQLMap.orderbykey);
			long num;
			try {
				num = Long.parseLong(String.valueOf(CommonMap.getValue(query(cnt, newone))));
			} catch (Exception e) {
				// Same message as before, but the original exception is kept as the cause.
				throw new RuntimeException(e.toString(), e);
			}
			if (0 == num) {
				doUpdate = false;
			} else if (1 == num) {
				// An existing record was found, so it can be updated.
				doUpdate = true;
			} else {
				throw new RuntimeException("不允许用REPLACE批量更新多条记录，请使用UPDATE");
			}
		}
		if (doUpdate) {
			if (hasCondition) {
				return update(cnt, "UPDATE", map.get(SQLMap.objectskey)
					,setWithMap(map)
					,whereWithMap(map));
			} else {
				throw new RuntimeException("禁止无条件的全表更新,如果要全表更新,请显性设置条件如1=1");
			}
		} else {
			// NOTE(review): the unboxing below throws NullPointerException when
			// SQLMap.insertwithlistkey is absent - confirm SQLMap always sets it.
			return update(cnt, "INSERT INTO", map.get(SQLMap.objectskey)
				,IF(false == (Boolean)map.get(SQLMap.insertwithlistkey), fieldWithMap(map))
				,valueWithMap(map));
		}
	}

	/**
	 * Deletes a single record described by a map.
	 * <p>
	 * Nothing is executed unless the map carries a non-empty where-map, which
	 * prevents an unconditional full-table delete. The row-count check happens
	 * after the DELETE statement has been executed.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param map SQLMap describing the record to delete
	 * @return 1 when exactly one row was deleted, otherwise 0
	 * @throws RuntimeException if the statement affected more than one row
	 */
	@SuppressWarnings("unchecked")
	public int delete(Connection cnt, Map<String, Object> map) {
		Object wheres = map.get(SQLMap.whereskey);
		if (null == wheres || ((Map<String, Object>) wheres).size() <= 0) {
			return 0;
		}

		int affected = update(cnt, "DELETE FROM", map.get(SQLMap.objectskey)
			,whereWithMap(map));
		if (affected > 1) {
			throw new RuntimeException("不能删除多条记录");
		}
		return (1 == affected) ? 1 : 0;
	}

	/**
	 * Runs a data-changing statement (UPDATE/INSERT/DELETE and the like) assembled
	 * from variable-length parts.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param sqlA parts that make up the SQL statement
	 * @return number of affected rows
	 */
	public int update(Connection cnt, Object... sqlA) {
		SQL statement = new SQL(sqlA);
		return update(cnt, statement);
	}
	
	/** Indentation placed in front of every line of a log entry. */
	private final static String _logprefix = "   ";

	/**
	 * Builds the log entry for one executed statement.
	 * <p>
	 * The entry names the call site (the first stack frame whose class name does
	 * not start with {@code "thant."}, or the last frame), the SQL text, the bind
	 * arguments, and either the elapsed time plus result or the error message.
	 *
	 * @param sqlbody SQL text
	 * @param args bind arguments; {@code null} is treated as no arguments
	 * @param cput start time in milliseconds
	 * @param ret result: an {@code Integer} row count or a list of row maps; {@code null} is logged as an empty result
	 * @param errmsg error message, or {@code null} when the statement succeeded
	 * @return the formatted log entry
	 */
	private String getCommonSQLLog(String sqlbody, Object[] args, long cput, Object ret, String errmsg) {
		long endt = System.currentTimeMillis();
		StringBuilder sb = new StringBuilder(), tmpsb = new StringBuilder();

		sb.append("\r\n<<<----------------------------SQLRunner start----------------------------\r\n");

		StackTraceElement[] stackTraceElements= Thread.currentThread().getStackTrace();
		for (int i=2; i<stackTraceElements.length; ++i) {
			if (!stackTraceElements[i].getClassName().startsWith("thant.") || i == stackTraceElements.length-1) {
				sb.append(_logprefix).append('[').append(stackTraceElements[i].getClassName())
					.append('.').append(stackTraceElements[i].getMethodName())
					.append(":").append(stackTraceElements[i].getLineNumber())
					.append(']');
				break;
			}
		}

		sb.append("\r\n").append(_logprefix).append(sqlbody);
		sb.append("\r\n").append(_logprefix).append("参数:[");

		if (args != null) {
			for (Object item : args) {
				// append(Object) renders a null element as "null".
				tmpsb.append(',').append(item);
			}
		}
		if (tmpsb.length()>0) {
			sb.append(tmpsb.substring(1));
		}
		sb.append("]\r\n").append(_logprefix);

		if (null == errmsg) {
			// Print the result.
			sb.append("cost ").append(endt - cput).append(" ms. ");
			if (ret instanceof Integer) {
				sb.append("updated rows:").append(ret);
			} else if (ret instanceof List) {
				List<?> lst = (List<?>) ret;

				sb.append("result count:").append(lst.size());
				for (Object row : lst) {
					tmpsb.setLength(0);
					if (row instanceof Map) {
						for (Entry<?, ?> item : ((Map<?, ?>) row).entrySet()) {
							tmpsb.append(',').append(item.getKey()).append(":").append(item.getValue());
						}
					}
					sb.append("\r\n").append(_logprefix).append("{");
					if (tmpsb.length()>0) {
						sb.append(tmpsb.substring(1));
					}
					sb.append("}");
				}
			} else if (null == ret) {
				// A missing result no longer aborts logging with a NullPointerException.
				sb.append("result count:0");
			} else {
				sb.append("result:").append(ret);
			}
		} else {
			sb.append(errmsg);
		}

		sb.append("\r\n----------------------------------------------------------------------->>>");

		return sb.toString();
	}
	
	/**
	 * Runs a data-changing statement (UPDATE/INSERT/DELETE and the like).
	 * <p>
	 * The SQL text and arguments are remembered as the "last" statement. When a
	 * log method is registered, a log entry is written whether the statement
	 * succeeds or fails.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param sql SQL object to run, see {@link SQL}
	 * @return number of affected rows; 0 when no connection is available or {@code sql} is {@code null}
	 * @throws RuntimeException wrapping the {@code SQLException} raised by the statement
	 */
	public int update(Connection cnt, SQL sql) {
		// A null statement is rejected in the same way query(Connection, SQL) rejects it.
		if ((null == cnt && null == connect) || null == sql) return 0;

		lastSql  = sql.getSql();
		lastArgs = sql.getArgs();

		int rows = 0;
		long cput = 0;
		String errmsg = null;
		SQLException failure = null;
		try {
			if (logfunc != null) {
				cput = System.currentTimeMillis();
			}
			rows = runner.update(null==cnt ? connect.getConnection() : cnt, lastSql, lastArgs);
		} catch (SQLException e) {
			failure = e;
			// Fall back to toString() so an exception without a message is still reported.
			errmsg = (e.getMessage() != null) ? e.getMessage() : e.toString();
		} finally {
			if (logfunc != null) {
				try {
					logfunc.invoke(logger, this.getCommonSQLLog(lastSql, lastArgs, cput, rows, errmsg));
				} catch (Exception e1) { e1.printStackTrace(); }
			}
		}
		if (failure != null) {
			// Thrown outside finally, with the SQLException preserved as the cause.
			throw new RuntimeException(errmsg, failure);
		}
		return rows;
	}

	/**
	 * Runs a SELECT described by a map.
	 * <p>
	 * The statement is assembled as {@code SELECT <fields> FROM <object>}, followed
	 * by the clause produced by {@code whereWithMap(map)}, an {@code ORDER BY} part
	 * when an order-by entry is present, and a {@code LIMIT <begin> , <end>} part
	 * when both paging entries are present. An empty field list selects {@code *}.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param map query, paging and ordering description, see {@link SQLMap}
	 * @return the result rows
	 */
	public List<Map<String, Object>> query(Connection cnt, Map<String, Object> map) {
		String requested = (String) map.get(SQLMap.fieldskey);
		String fields = isEmpty(requested) ? "*" : requested;

		Connection target = (null == cnt) ? connect : cnt;
		Object table = map.get(SQLMap.objectskey);
		Object where = whereWithMap(map);
		Object ordering = IF(!isEmpty((String) map.get(SQLMap.orderbykey)),
			"ORDER BY", map.get(SQLMap.orderbykey));
		Object paging = IF(map.get(SQLMap.pagebeginkey) != null && map.get(SQLMap.pageendkey) != null,
			"LIMIT", V(map.get(SQLMap.pagebeginkey)), ",", V(map.get(SQLMap.pageendkey)));

		return query(target, "SELECT", fields, "FROM", table, where, ordering, paging);
	}

	/**
	 * Runs a query assembled from variable-length parts.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param sqlA parts that make up the SQL statement
	 * @return the result rows
	 */
	public List<Map<String, Object>> query(Connection cnt, Object... sqlA) {
		SQL statement = new SQL(sqlA);
		return query(cnt, statement);
	}
	
	/**
	 * Runs a query and returns every row as a map keyed by column name.
	 * <p>
	 * The SQL text and arguments are remembered as the "last" statement. When the
	 * field-case setting is {@code lower} or {@code upper}, the column keys of all
	 * rows are converted accordingly. When a log method is registered, a log entry
	 * is written whether the query succeeds or fails.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param sql SQL object to run, see {@link SQL}
	 * @return the result rows; {@code null} when no connection is available or {@code sql} is {@code null}
	 * @throws RuntimeException wrapping the {@code SQLException} raised by the query
	 */
	public List<Map<String, Object>> query(Connection cnt, SQL sql) {
		if (null == cnt && null == connect || null == sql) return null;

		lastSql  = sql.getSql();
		lastArgs = sql.getArgs();

		List<Map<String, Object>> ret = null;
		long cput = 0;
		String errmsg = null;
		SQLException failure = null;
		try {
			if (logfunc != null) {
				cput = System.currentTimeMillis();
			}
			ret = runner.query(null==cnt ? connect.getConnection() : cnt, lastSql, new MapListHandler(), lastArgs);
			applyForcedFieldCase(ret);
		} catch (SQLException e) {
			failure = e;
			// Fall back to toString() so an exception without a message is still reported.
			errmsg = (e.getMessage() != null) ? e.getMessage() : e.toString();
		} finally {
			if (logfunc != null) {
				try {
					logfunc.invoke(logger, this.getCommonSQLLog(lastSql, lastArgs, cput, ret, errmsg));
				} catch (Exception e1) { e1.printStackTrace(); }
			}
		}
		if (failure != null) {
			// Thrown outside finally, with the SQLException preserved as the cause.
			throw new RuntimeException(errmsg, failure);
		}
		return ret;
	}

	/**
	 * Renames the column keys of every row to the configured case; does nothing
	 * when the setting is {@code none} or there are no rows.
	 */
	private void applyForcedFieldCase(List<Map<String, Object>> rows) {
		if (null == rows || rows.isEmpty() || "none".equals(forceFieldCase)) {
			return;
		}
		boolean lower = "lower".equals(forceFieldCase);

		// The first row determines which columns need renaming.
		List<String> colA = new ArrayList<String>();
		List<String> newcolA = new ArrayList<String>();
		for (String col : rows.get(0).keySet()) {
			// Locale.ROOT keeps the conversion independent of the default locale (e.g. Turkish dotless i).
			String newcol = lower ? col.toLowerCase(Locale.ROOT) : col.toUpperCase(Locale.ROOT);
			if (!newcol.equals(col)) {
				colA.add(col);
				newcolA.add(newcol);
			}
		}
		if (colA.isEmpty()) {
			return;
		}
		for (Map<String, Object> row : rows) {
			for (int j=0; j<colA.size(); ++j) {
				Object val = row.remove(colA.get(j));
				row.put(newcolA.get(j), val);
			}
		}
	}

	/**
	 * Runs a query and reduces its result to a single value through
	 * {@code CommonMap.getValue}.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param sqlA parts that make up the SQL statement
	 * @return the value extracted from the result
	 */
	public Object queryValue(Connection cnt, Object... sqlA) {
		List<Map<String, Object>> rows = query(cnt, sqlA);
		return CommonMap.getValue(rows);
	}

	/**
	 * Runs a counting query and returns its single value as a {@code long}.
	 * <p>
	 * Any numeric result type is converted through {@code Number.longValue()};
	 * other values are parsed from their string form.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param sqlA parts that make up the SQL statement
	 * @return the value as a {@code long}; 0 when the query yields no value
	 * @throws NumberFormatException if a non-numeric value cannot be parsed as a long
	 */
	public long queryCount(Connection cnt, Object... sqlA) {
		Object val = CommonMap.getValue(query(cnt, sqlA));
		if (null == val) {
			return 0;
		}
		if (val instanceof Number) {
			// Covers Long and Integer as before, plus BigDecimal, BigInteger, Short and so on.
			return ((Number) val).longValue();
		}
		return Long.parseLong(String.valueOf(val).trim());
	}

	/**
	 * Runs a query and reduces its result to one row through
	 * {@code CommonMap.getOne}.
	 *
	 * @param cnt connection to use; {@code null} means the connection associated with this runner
	 * @param sqlA parts that make up the SQL statement
	 * @return the row extracted from the result
	 */
	public Map<String, Object> queryOne(Connection cnt, Object... sqlA) {
		List<Map<String, Object>> rows = query(cnt, sqlA);
		return CommonMap.getOne(rows);
	}

	/**
	 * Returns the connection associated with this runner.
	 *
	 * @return the wrapped connection, or {@code null} when the runner has none
	 * @see SilentConnection
	 */
	public SilentConnection getConnect() {
		return this.connect;
	}
	
	/**
	 * Returns the database product name reported by the connection's metadata.
	 *
	 * @return the product name; an empty string when the runner has no
	 *         connection, no metadata is available, or the lookup fails
	 */
	public String getDBType() {
		if (null == connect) {
			return "";
		}
		DatabaseMetaData dbmd = connect.getMetaData();
		if (null == dbmd) {
			// Guard against a wrapper that hands back no metadata instead of throwing.
			return "";
		}
		try {
			return dbmd.getDatabaseProductName();
		} catch (SQLException e) {
			e.printStackTrace();
			return "";
		}
	}

	/**
	 * Commits the connection associated with this runner; does nothing when the
	 * runner has no connection.
	 */
	public void commit() {
		if (null == connect) {
			return;
		}
		connect.commit();
	}
	
	/**
	 * Rolls back the connection associated with this runner; does nothing when
	 * the runner has no connection.
	 *
	 * @param sp savepoint to roll back to; {@code null} rolls back without a savepoint
	 */
	public void rollback(Savepoint sp) {
		if (null == connect) {
			return;
		}
		if (sp != null) {
			connect.rollback(sp);
		} else {
			connect.rollback();
		}
	}
	
	/**
	 * @return the directory that script file names are resolved against
	 */
	public String getScriptroot() {
		return this.scriptroot;
	}

	/**
	 * Sets the directory that script file names are resolved against.
	 *
	 * @param scriptroot script root directory
	 */
	public void setScriptroot(String scriptroot) {
		String root = scriptroot;
		this.scriptroot = root;
	}

	/**
	 * Joins the script root, the platform file separator and a script name.
	 *
	 * @param scriptname name of the script file
	 * @return the combined path
	 */
	private String getScriptPath(String scriptname) {
		StringBuilder path = new StringBuilder(scriptroot);
		path.append(File.separator);
		path.append(scriptname);
		return path.toString();
	}
	
	/**
	 * Opens the named script file under the script root and closes it again.
	 * <p>
	 * The file content is not read and nothing is executed; the method always
	 * returns {@code null}. A missing file is reported by printing the stack trace.
	 *
	 * @param cnt connection argument; not used by the current implementation
	 * @param filename script file name, relative to the script root
	 * @return always {@code null}
	 */
	public Object execute(Connection cnt, String filename) {
		try (InputStream stm = new FileInputStream(getScriptPath(filename))) {
			// Nothing is read from the stream.
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} catch (IOException ignored) {
			// A failure while closing the stream is ignored.
		}
		return null;
	}

	/**
	 * @return the object the log method is invoked on; {@code null} when none is set or the log method is static
	 */
	public Object getLogger() {
		return this.logger;
	}

	/**
	 * Registers the method used to write SQL log entries.
	 * <p>
	 * When {@code obj} is a {@code String}, it is taken as a class name and
	 * {@code level} names a static method on that class. Otherwise {@code level}
	 * names an instance method on {@code obj}. In both cases the method must be
	 * public and take a single {@code String} parameter.
	 * <p>
	 * If the class or method cannot be resolved, the stack trace is printed and
	 * the previously registered logger is kept unchanged.
	 *
	 * @param obj logger instance, or the fully qualified name of a class with a static log method
	 * @param level name of the log method to call
	 */
	public void setLogger(Object obj, String level) {
		try {
			boolean staticTarget = obj instanceof String;
			Class<?> cls = staticTarget
				? Class.forName((String)obj) // static method
				: obj.getClass();            // instance method
			Method resolved = cls.getMethod(level, String.class);

			// Assign both fields only after the lookup succeeded, so that a failed
			// call can never pair a new logger with the old method.
			this.logger = staticTarget ? null : obj;
			this.logfunc = resolved;
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * @return the current field-case setting: {@code "none"}, {@code "lower"} or {@code "upper"}
	 */
	public String getForceFieldCase() {
		return this.forceFieldCase;
	}

	/**
	 * Selects how column names in query results are converted.
	 * <p>
	 * Accepted values, compared case-insensitively, are {@code "lower"},
	 * {@code "upper"} and {@code "none"}. Any other value, including {@code null},
	 * leaves the current setting unchanged.
	 *
	 * @param force requested mode
	 */
	public void setForceFieldCase(String force) {
		for (String mode : new String[] {"lower", "upper", "none"}) {
			// Calling equalsIgnoreCase on the constant makes a null argument harmless.
			if (mode.equalsIgnoreCase(force)) {
				this.forceFieldCase = mode;
				return;
			}
		}
	}
}
